-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[MLIR][Vector]Add constraints to vector.shape_cast(constant) -> constant #147691
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Mengmeng Sun (MengmSun) ChangesWe have the case that after
Our next pass is
and we found that's because a So we want to add the constraints that only when the element type of the source attribute and return type are the same it will return Full diff: https://github.com/llvm/llvm-project/pull/147691.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 214d2ba7e1b8e..5bbe6704aac48 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5922,10 +5922,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return bcastOp.getSource();
}
- // shape_cast(constant) -> constant
+ // shape_cast(constant) -> constant,
+ // if element type of the source and result are the same
if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+ if (splatAttr.getElementType() == resultType.getElementType())
+ return splatAttr.reshape(getType());
+ }
// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8a9e27378df61..69da8a31d2c9b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1002,6 +1002,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
// -----
+// CHECK-LABEL: func @canonicalize_extract_shapecast_different_element_type
+func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi8> {
+ %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>
+ // CHECK-NOT: vector.shape_cast
+ %1 = vector.shape_cast %0 : vector<12xi8> to vector<1x12xi8>
+ // CHECK-NOT: vector.extract
+ %2 = vector.extract %1[0] : vector<12xi8> from vector<1x12xi8>
+ return %2 : vector<12xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
// CHECK: vector.broadcast
// CHECK-NOT: vector.shape_cast
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR #133988 moved the canonicalizer to a folder, that is probably what triggered your error.
I am not sure if
%0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi8>
is valid. I can see the bitwidth of the element, and the number of elements is the same between the 2 vectors. But I'd have thought they must be identical. Maybe there is a verification missing in LogicalResult LLVM::ConstantOp::verify()
?
When I run
mlir-opt --verify-diagnostics playtime.mlir
on
func.func @canonicalize_extract_shapecast_different_element_type()->vector<12xi128> {
%0 = llvm.mlir.constant(dense<0.000000e+00> : vector<12xf8E4M3FN>) : vector<12xi128>
%1 = vector.shape_cast %0 : vector<12xi128> to vector<1x12xi128>
%2 = vector.extract %1[0] : vector<12xi128> from vector<1x12xi128>
return %2 : vector<12xi128>
}
I don't get an error either, but this example looks especially wrong because the number of bits is different between the element types.
The following question is unrelated but I've seen this popping up multiple times lately: why |
I think this is actually central to the PR! |
We have the case that after
ConvertToLLVMPass
it looks like:Our next pass is
Canonicalizer
. Several months ago everything went smoothly. However recently we met problem thatand we found that's because a
reshape
operation is added forvector.shape_cast(constant) -> constant
when callingShapeCastOp::fold()
in the Canonicalizer pass. This operation will fail if the element type of the source attribute and return type are different.So we want to add the constraints that only when the element type of the source attribute and return type are the same it will return
reshape
operation to make our case work as before and will not influence other cases.